import os
from datasets import Dataset
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
import torch
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import gc
from sklearn.metrics import balanced_accuracy_score, precision_recall_fscore_support, accuracy_score
from sklearn.metrics import f1_score as f1
from sklearn.metrics import matthews_corrcoef as mcc
from peft import (
    LoraConfig,
    PeftModel,
    PeftConfig,
    TaskType,
    get_peft_model
)
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
)
import warnings
from transformers import logging
# Only surface error-level messages from transformers, and hide the noisy
# User/Future warnings emitted by the libraries imported above.
logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Turn off pandas' chained-assignment (SettingWithCopy) warning.
pd.options.mode.chained_assignment = None

# Hugging Face access token, read from a local file. The context manager
# closes the file handle, and strip() drops the trailing newline that most
# editors append (a newline inside the token breaks authentication).
with open('./llama_key.txt', 'r') as key_file:
    llama_key = key_file.read().strip()

# Labeled tweets. This script reads the 'text' column (tweet body) and the
# 'non_comp' column (gold label) further down.
covid = pd.read_csv('./data/covid_tweets_labeled.csv', encoding="ISO-8859-1")

# Hugging Face model identifier of the base model; used below to load both
# the tokenizer and the text-generation pipeline.
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Classification prompt sent as the user turn. The single `{}` placeholder is
# filled with the tweet text via `str.format`; the model is instructed to
# answer with only "1" (minimizing) or "0" (not minimizing).
prompt_template = """You are a classifier that determines if a tweet is minimizing the threat of COVID-19.

Tweets that minimize the threat of COVID-19 do one or more of the following:
- Question how dangerous the virus is or claim that death totals related to the virus are inflated or fake.
- Make statements that are anti-vaccination or claim that the COVID-19 vaccine is ineffective or dangerous.
- Criticize the use of masks or mask mandates.
- Criticize social distancing guidelines, lockdowns, or stay-at-home orders.
- Insinuate that COVID-19 is a conspiracy or a tool to control people.

I will show you a tweet. Read the tweet and then determine if the tweet is minimizing the threat of COVID-19. Here is the tweet:

{}

If the tweet is minimizing the threat of COVID-19, return 1. If it is not minimizing the threat of COVID-19, return 0. Do not explain your answer, and only return the number."""

def format_example(example):
    """Turn one labeled example into a two-message chat transcript.

    Args:
        example: Mapping with a "tweet" entry (the tweet text) and a "label"
            entry (the expected answer).

    Returns:
        A list of two chat messages: the user turn holding the filled-in
        classification prompt, followed by the assistant turn holding the
        label converted to a string.
    """
    # The user turn is the prompt with the tweet substituted in.
    user_turn = {'role': 'user', 'content': prompt_template.format(example["tweet"])}
    # The assistant turn is the target completion the model should produce.
    assistant_turn = {'role': 'assistant', 'content': str(example["label"])}
    return [user_turn, assistant_turn]

# Uncomment this section if you want to re-train the adapters. Takes ~10 hours on an RTX 3090
# ###################################
# ## Parameter Efficient Fine Tuning
# ###################################
# # Train 30 different adapters with 30 different seeds
# for seed in range(0,30):
#     # Random sample of 25 tweets for training
#     train_samps = covid.sample(25, random_state = seed)
#     train_examples = [{'tweet': train_samps.loc[i, 'text'], 'label': str(train_samps.loc[i, 'non_comp'])} for i in train_samps.index]
    
#     # Convert each example into a chat-format message list (user prompt + assistant label).
#     raw_dataset = [format_example(ex) for ex in train_examples]

#     # initialize tokenizer and model
#     tokenizer = AutoTokenizer.from_pretrained(
#         model_name,
#         use_fast=True,
#         trust_remote_code=True,
#     )
#     tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
#     tokenizer.padding_side = "right"
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         trust_remote_code=True,
#         device_map="cuda",
#     )
    
#     # Apply chat template and convert to dataset
#     raw_dataset = tokenizer.apply_chat_template(raw_dataset, tokenize = False)
#     train_samps['text'] = raw_dataset
#     dataset = Dataset.from_pandas(train_samps)
    
#     # This section makes it so that the loss is computed only on the generated text, rather than everything in the prompt.
#     response_template = "<|end_header_id|>"
#     collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer = tokenizer)
#     examples = list(dataset["text"])
#     encodings = [tokenizer(e) for e in examples]
#     dataloader = DataLoader(encodings, collate_fn = collator, batch_size = 1)
    
#     # Configure LoRA
#     lora_config = LoraConfig(
#         task_type=TaskType.CAUSAL_LM,
#         inference_mode=False,
#         r=32, # The rank of the adapter matrix. Higher rank matrices require more memory and adapt the model more
#         lora_alpha= 8, # The scaling factor for the adapter. Higher values mean the adapter has more influence relative to frozen parameters.
#         # Try reducing alpha first if the model's output becomes degenerate.
#         lora_dropout=0.05, # regularization
#         target_modules=["q_proj", "v_proj"], # Targets the query and value matrices, a common default.
#     )
    
#     model = get_peft_model(model, lora_config)
    
#     # tokenize the dataset
#     def tokenize_function(example):
#         # Tokenize the text; you can adjust max_length and padding as required.
#         return tokenizer(example["text"], truncation=True, padding="max_length", max_length=512)
#     tokenized_dataset = dataset.map(tokenize_function, batched=True)

#     # configure the trainer
#     sft_config = SFTConfig(
#         output_dir = r'./llama/llama_ft/seed_{}'.format(seed),
#         dataset_text_field = 'text',
#         max_seq_length = 512,
#         num_train_epochs = 5,
#         per_device_train_batch_size = 1,
#         per_device_eval_batch_size = 1,
#         gradient_accumulation_steps = 2,
#         eval_strategy = 'no',
#         learning_rate = 1e-4,
#         bf16 = True,
#         save_total_limit=1,
#         lr_scheduler_type='constant',
#         dataset_kwargs = {
#             "add_special_tokens": False,
#             "append_concat_token": False,
#         },
#         seed = seed,
#     )

#     # Initialize trainer
#     trainer = SFTTrainer(
#         model = model,
#         args=sft_config,
#         train_dataset=dataset,
#         processing_class=tokenizer,
#         data_collator=collator,
#         )

#     # Train
#     print("Training seed: " + str(seed))    
#     trainer.train()

#     # Clear memory for next iteration of the loop
#     del model
#     del trainer
#     gc.collect()
#     torch.cuda.empty_cache()

#############
## Inference
############
# Loop through each adapter and label docs

# The tokenizer does not depend on the adapter, so load it once and reuse it
# for every seed. The access token is passed because the base model is gated.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=llama_key)

for seed in range(0, 30):
    print("Classifying with few-shot adapter " + str(seed))
    adapter_model_id = r"./llama/llama_ft/seed_{}/checkpoint-60".format(seed)

    # Build a fresh base-model pipeline for this seed so that adapters from
    # previous iterations cannot leak into this one.
    pipe = pipeline(
        "text-generation",
        model=model_name,
        tokenizer=tokenizer,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="cuda",
        token=llama_key,
    )

    # load_adapter reads BOTH the adapter config and the trained LoRA weights
    # from the checkpoint directory. (add_adapter with a PeftConfig would only
    # inject freshly initialised LoRA layers and never load the fine-tuned
    # weights, so the classifier would run with an untrained adapter.)
    pipe.model.load_adapter(adapter_model_id)
    pipe.model.enable_adapters()

    generations = []
    for doc in covid['text']:
        messages = [
            {"role": "user", "content": prompt_template.format(doc)},
        ]
        prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = pipe(prompt, max_new_tokens=2, do_sample=False, return_full_text=False, pad_token_id=pipe.tokenizer.eos_token_id, temperature=0)
        generations.extend(outputs)

    # Only the first non-whitespace character of each completion is treated as
    # the label. Stripping and slicing (instead of indexing [0]) means an empty
    # or whitespace-prefixed completion is handled without an IndexError;
    # anything other than "1" is recorded as 0.
    covid[r'llama_peft{}'.format(seed)] = [
        1 if out['generated_text'].strip()[:1] == '1' else 0
        for out in generations
    ]

    # Release the model before the next iteration loads another copy.
    del pipe
    gc.collect()
    torch.cuda.empty_cache()

# F1 score of each adapter's predictions against the gold 'non_comp' labels.
scores = [
    f1(covid['non_comp'], covid[r'llama_peft{}'.format(seed)])
    for seed in range(0, 30)
]
# Report the mean across adapters. A bare `np.mean(scores)` expression is
# silently discarded when this file is run as a script, so print it.
mean_f1 = np.mean(scores)
print("Mean F1 across adapters: {:.4f}".format(mean_f1))
# Export results (original columns plus one prediction column per seed).
covid.to_csv('./data/covid_llama_peft_res.csv', index=False)